{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": []
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "https://raw.githubusercontent.com/BobXWu/TopMost/master/data/20NG.zip\n",
      "Downloading https://raw.githubusercontent.com/BobXWu/TopMost/master/data/20NG.zip to ./datasets/20NG.zip\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11927380/11927380 [00:00<00:00, 12169265.06it/s]\n"
     ]
    }
   ],
   "source": [
    "# Setup: fetch the 20 Newsgroups (20NG) benchmark bundled with TopMost.\n",
    "import topmost\n",
    "from topmost.data import download_dataset\n",
    "\n",
    "# Device used by all later cells; switch to \"cpu\" if no GPU is available.\n",
    "device = \"cuda\" # or \"cpu\"\n",
    "# Directory the archive is unpacked into; later cells read from here.\n",
    "dataset_dir = \"./datasets/20NG\"\n",
    "# Downloads 20NG.zip into ./datasets (see the cell output above) and extracts it.\n",
    "download_dataset('20NG', cache_path='./datasets')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_size:  11314\n",
      "test_size:  7532\n",
      "vocab_size:  5000\n",
      "average length: 110.543\n"
     ]
    }
   ],
   "source": [
    "########################### Neural Topic Models ####################################\n",
    "# Build the dataset object consumed by the neural topic models.\n",
    "# read_labels=True loads document labels (used by the clustering/classification\n",
    "# evaluation cell below); pretrained_WE=True loads pretrained word embeddings\n",
    "# (used by ETM/NSTM/ECRTM in the training cell).\n",
    "# For CombinedTM, also pass contextual_embed=True.\n",
    "dataset = topmost.data.BasicDataset(dataset_dir, read_labels=True, device=device, pretrained_WE=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 4/200 [00:00<00:18, 10.64it/s]2024-06-19 22:15:05,129 - TopMost - Epoch: 005 loss: 943.007\n",
      "  4%|▍         | 8/200 [00:00<00:18, 10.65it/s]2024-06-19 22:15:05,598 - TopMost - Epoch: 010 loss: 889.451\n",
      "  7%|▋         | 14/200 [00:01<00:17, 10.65it/s]2024-06-19 22:15:06,069 - TopMost - Epoch: 015 loss: 870.886\n",
      "  9%|▉         | 18/200 [00:01<00:17, 10.63it/s]2024-06-19 22:15:06,540 - TopMost - Epoch: 020 loss: 860.866\n",
      " 12%|█▏        | 24/200 [00:02<00:16, 10.64it/s]2024-06-19 22:15:07,011 - TopMost - Epoch: 025 loss: 854.163\n",
      " 14%|█▍        | 28/200 [00:02<00:16, 10.63it/s]2024-06-19 22:15:07,481 - TopMost - Epoch: 030 loss: 849.931\n",
      " 17%|█▋        | 34/200 [00:03<00:15, 10.64it/s]2024-06-19 22:15:07,952 - TopMost - Epoch: 035 loss: 846.811\n",
      " 19%|█▉        | 38/200 [00:03<00:15, 10.63it/s]2024-06-19 22:15:08,422 - TopMost - Epoch: 040 loss: 844.818\n",
      " 22%|██▏       | 44/200 [00:04<00:14, 10.64it/s]2024-06-19 22:15:08,893 - TopMost - Epoch: 045 loss: 843.003\n",
      " 24%|██▍       | 48/200 [00:04<00:14, 10.63it/s]2024-06-19 22:15:09,364 - TopMost - Epoch: 050 loss: 841.798\n",
      " 27%|██▋       | 54/200 [00:05<00:13, 10.63it/s]2024-06-19 22:15:09,835 - TopMost - Epoch: 055 loss: 840.788\n",
      " 29%|██▉       | 58/200 [00:05<00:13, 10.62it/s]2024-06-19 22:15:10,305 - TopMost - Epoch: 060 loss: 839.667\n",
      " 32%|███▏      | 64/200 [00:06<00:12, 10.62it/s]2024-06-19 22:15:10,777 - TopMost - Epoch: 065 loss: 838.765\n",
      " 34%|███▍      | 68/200 [00:06<00:12, 10.62it/s]2024-06-19 22:15:11,247 - TopMost - Epoch: 070 loss: 837.890\n",
      " 37%|███▋      | 74/200 [00:06<00:11, 10.63it/s]2024-06-19 22:15:11,719 - TopMost - Epoch: 075 loss: 837.049\n",
      " 39%|███▉      | 78/200 [00:07<00:11, 10.61it/s]2024-06-19 22:15:12,189 - TopMost - Epoch: 080 loss: 836.473\n",
      " 42%|████▏     | 84/200 [00:07<00:10, 10.64it/s]2024-06-19 22:15:12,661 - TopMost - Epoch: 085 loss: 835.645\n",
      " 44%|████▍     | 88/200 [00:08<00:10, 10.62it/s]2024-06-19 22:15:13,131 - TopMost - Epoch: 090 loss: 834.892\n",
      " 47%|████▋     | 94/200 [00:08<00:09, 10.63it/s]2024-06-19 22:15:13,602 - TopMost - Epoch: 095 loss: 834.325\n",
      " 49%|████▉     | 98/200 [00:09<00:09, 10.61it/s]2024-06-19 22:15:14,074 - TopMost - Epoch: 100 loss: 833.927\n",
      " 52%|█████▏    | 104/200 [00:09<00:09, 10.62it/s]2024-06-19 22:15:14,545 - TopMost - Epoch: 105 loss: 833.317\n",
      " 54%|█████▍    | 108/200 [00:10<00:08, 10.61it/s]2024-06-19 22:15:15,016 - TopMost - Epoch: 110 loss: 832.998\n",
      " 57%|█████▋    | 114/200 [00:10<00:08, 10.61it/s]2024-06-19 22:15:15,489 - TopMost - Epoch: 115 loss: 832.499\n",
      " 59%|█████▉    | 118/200 [00:11<00:07, 10.61it/s]2024-06-19 22:15:15,959 - TopMost - Epoch: 120 loss: 832.115\n",
      " 62%|██████▏   | 124/200 [00:11<00:07, 10.63it/s]2024-06-19 22:15:16,430 - TopMost - Epoch: 125 loss: 831.733\n",
      " 64%|██████▍   | 128/200 [00:12<00:06, 10.60it/s]2024-06-19 22:15:16,902 - TopMost - Epoch: 130 loss: 831.271\n",
      " 67%|██████▋   | 134/200 [00:12<00:06, 10.61it/s]2024-06-19 22:15:17,374 - TopMost - Epoch: 135 loss: 830.971\n",
      " 69%|██████▉   | 138/200 [00:12<00:05, 10.60it/s]2024-06-19 22:15:17,846 - TopMost - Epoch: 140 loss: 830.616\n",
      " 72%|███████▏  | 144/200 [00:13<00:05, 10.62it/s]2024-06-19 22:15:18,317 - TopMost - Epoch: 145 loss: 830.435\n",
      " 74%|███████▍  | 148/200 [00:13<00:04, 10.60it/s]2024-06-19 22:15:18,790 - TopMost - Epoch: 150 loss: 830.017\n",
      " 77%|███████▋  | 154/200 [00:14<00:04, 10.61it/s]2024-06-19 22:15:19,269 - TopMost - Epoch: 155 loss: 829.840\n",
      " 79%|███████▉  | 158/200 [00:14<00:03, 10.52it/s]2024-06-19 22:15:19,739 - TopMost - Epoch: 160 loss: 829.524\n",
      " 82%|████████▏ | 164/200 [00:15<00:03, 10.62it/s]2024-06-19 22:15:20,207 - TopMost - Epoch: 165 loss: 829.301\n",
      " 84%|████████▍ | 168/200 [00:15<00:02, 10.68it/s]2024-06-19 22:15:20,671 - TopMost - Epoch: 170 loss: 829.061\n",
      " 87%|████████▋ | 174/200 [00:16<00:02, 10.57it/s]2024-06-19 22:15:21,155 - TopMost - Epoch: 175 loss: 828.828\n",
      " 89%|████████▉ | 178/200 [00:16<00:02, 10.37it/s]2024-06-19 22:15:21,635 - TopMost - Epoch: 180 loss: 828.582\n",
      " 92%|█████████▏| 184/200 [00:17<00:01, 10.55it/s]2024-06-19 22:15:22,106 - TopMost - Epoch: 185 loss: 828.341\n",
      " 94%|█████████▍| 188/200 [00:17<00:01, 10.60it/s]2024-06-19 22:15:22,574 - TopMost - Epoch: 190 loss: 827.966\n",
      " 97%|█████████▋| 194/200 [00:18<00:00, 10.64it/s]2024-06-19 22:15:23,044 - TopMost - Epoch: 195 loss: 827.824\n",
      " 99%|█████████▉| 198/200 [00:18<00:00, 10.64it/s]2024-06-19 22:15:23,513 - TopMost - Epoch: 200 loss: 827.777\n",
      "100%|██████████| 200/200 [00:18<00:00, 10.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Topic 0: problem know problems tell stuff thing manager widget experience bbs computing switch advice learn appreciate\n",
      "Topic 1: question believe yes think agree reason say whether fact evidence true know nothing answer anything\n",
      "Topic 2: writes usa read wrote reading quote write numbers warning writing express compared loaded hole arizona\n",
      "Topic 3: thanks drive card scsi hard disk memory work speed ibm fast bus computer drivers need\n",
      "Topic 4: good think know sure opinions get really pretty guess look got maybe run lot anybody\n",
      "Topic 5: david mark system james gordon systems per corporation jon roger stephen ray total null machine\n",
      "Topic 6: lines article computer based line internet phone california current entry communications systems previous model access\n",
      "Topic 7: distribution center brian chris illinois services appreciated deal doug eric care ron greg handle douglas\n",
      "Topic 8: organization science research technology institute engineering interested canada scientific nec carnegie radar knowledge position sciences\n",
      "Topic 9: file files keywords info apr etc check directory faq newsgroup msg uunet crypto int archive\n",
      "Topic 10: max please send email call address peter ask editor fax recommend doctor asked christopher jan\n",
      "Topic 11: host nntp newsreader news australia vms semi centris neutral duo coverage centre abc summer update\n",
      "Topic 12: key chip encryption non bit using one either two number set anyone phone another least\n",
      "Topic 13: reply sun tin disclaimer apple matthew quality gif isa vax electronics math packard hewlett paper\n",
      "Topic 14: article gun right guns weapons state firearms laws arms house amendment tax illegal constitution issue\n",
      "Topic 15: also article many one see called part since even however found times much seen net\n",
      "Topic 16: windows version window image color screen display video mouse type mode size format deleted copy\n",
      "Topic 17: first world two last since second april germany week three third following multi chicago post\n",
      "Topic 18: israeli turkish jews israel jewish armenian armenians arab turks armenia soviet azerbaijan palestinian nazi turkey\n",
      "Topic 19: inc list package include box includes sale including unit excellent included shipping currently usa books\n",
      "Topic 20: radio sorry bad univ mine btw station uucp school quadra convert mass don wondering denver\n",
      "Topic 21: people said know day going like never think just come told even remember days say\n",
      "Topic 22: people government american president states war clinton country today support years office united political white\n",
      "Topic 23: use data program system also available using used output run programs application systems uses set\n",
      "Topic 24: high monitor standard controller motif ide tape low wire drive austin audio install monitors gateway\n",
      "Topic 25: bike car front just motorcycle going ride cars got speed away riding little road thing\n",
      "Topic 26: dos software mail mac ftp server graphics pub information computer unix fax email hardware dod\n",
      "Topic 27: subject name note open order source following code number gmt available message used date section\n",
      "Topic 28: writes bill advance get take keys escrow return level dave frank secure assume gary place\n",
      "Topic 29: pittsburgh toronto nhl hockey teams san league scored vancouver los rangers goals boston detroit canada\n",
      "Topic 30: new national information old book anti news service york secret details corp laboratory cambridge recently\n",
      "Topic 31: possible made done try work something probably keith smith someone comments andrew tried make works\n",
      "Topic 32: one like time john clipper robert heard ever test years paul south north originator stop\n",
      "Topic 33: game year team play games best players like season player points time baseball lost win\n",
      "Topic 34: god jesus people christian faith christians bible christ church life religion believe one christianity many\n",
      "Topic 35: better now end course good running start year still around get probably way just going\n",
      "Topic 36: much buy car price cost board sell pay money new insurance sale bought cheap market\n",
      "Topic 37: one point way example different mean just means time best case make even difference given\n",
      "Topic 38: university state michael department dept college georgia body virginia tim ohio student daniel stanford brain\n",
      "Topic 39: organization etc group black pro internet figure lab cops eisa mask voice jeff ultra bitnet\n",
      "Topic 40: subject case jim discussion response condition cleveland legal cramer applied hot perry walker connected houston\n",
      "Topic 41: find just give long make far know yet seems free something true less anyone life\n",
      "Topic 42: lines like use help need want power anyone without used way system make work able\n",
      "Topic 43: space nasa years least air earth launch long orbit moon time high flight field year\n",
      "Topic 44: lines real summary line rather ones light ideas product cable original base thus fixed upon\n",
      "Topic 45: posting post steve mail mike posted texas bob tom general pin posts postings tech pins\n",
      "Topic 46: get now put says just got home already ago back keep couple say try help\n",
      "Topic 47: law public control government rights police security fbi court private local crime state enforcement federal\n",
      "Topic 48: health medical cases patients less food study certain drugs care studies disease treatment drug effective\n",
      "Topic 49: one time back left another see right side two went just first came started looking\n"
     ]
    }
   ],
   "source": [
    "# Create a neural topic model — uncomment exactly one alternative.\n",
    "# ETM/NSTM/ECRTM take the pretrained word embeddings loaded by the dataset cell;\n",
    "# CombinedTM additionally needs contextual_embed=True on the dataset.\n",
    "# model = topmost.ProdLDA(dataset.vocab_size)\n",
    "model = topmost.ETM(dataset.vocab_size, pretrained_WE=dataset.pretrained_WE)\n",
    "# model = topmost.DecTM(dataset.vocab_size)\n",
    "# model = topmost.TSCTM(dataset.vocab_size)\n",
    "# model = topmost.CombinedTM(dataset.vocab_size, dataset.contextual_embed_size)\n",
    "# model = topmost.NSTM(dataset.vocab_size, pretrained_WE=dataset.pretrained_WE)\n",
    "# model = topmost.ECRTM(dataset.vocab_size, pretrained_WE=dataset.pretrained_WE)\n",
    "model = model.to(device)\n",
    "\n",
    "# Create a trainer; verbose=True logs the per-epoch loss seen in the output above.\n",
    "trainer = topmost.BasicTrainer(model, dataset, verbose=True)\n",
    "\n",
    "# Train the model; returns the top words per topic and the training-set\n",
    "# doc-topic distributions (theta).\n",
    "top_words, train_theta = trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TD: 0.73067\n",
      "{'Purity': 0.38382899628252787, 'NMI': 0.3483013957003741}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'acc': 0.5326606479022836, 'macro-F1': 0.5171831529186143}\n"
     ]
    }
   ],
   "source": [
    "########################### Evaluate ####################################\n",
    "from topmost import eva\n",
    "\n",
    "# Get theta (doc-topic distributions) for the train and test splits.\n",
    "train_theta, test_theta = trainer.export_theta()\n",
    "\n",
    "# Topic coherence. NOTE: eva._coherence etc. are underscore-prefixed (private)\n",
    "# APIs and may change between TopMost versions.\n",
    "TC = eva._coherence(dataset.train_texts, dataset.vocab, top_words)\n",
    "print(f\"TC: {TC}\")  # fix: TC was computed but never reported\n",
    "\n",
    "# Topic diversity over the topics' top words.\n",
    "TD = eva._diversity(top_words)\n",
    "print(f\"TD: {TD:.5f}\")\n",
    "\n",
    "# Clustering of test doc-topic distributions vs. gold labels (Purity/NMI).\n",
    "clustering_results = eva._clustering(test_theta, dataset.test_labels)\n",
    "print(clustering_results)\n",
    "\n",
    "# Classification using theta as features (acc/macro-F1); distinct name so the\n",
    "# clustering results above are not silently shadowed.\n",
    "cls_results = eva._cls(train_theta, test_theta, dataset.train_labels, dataset.test_labels)\n",
    "print(cls_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "parse texts: 100%|██████████| 2/2 [00:00<00:00, 20068.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[43 16]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "########################### test new documents ####################################\n",
    "# Infer topics for unseen documents with the trained model.\n",
    "import torch\n",
    "from topmost import Preprocess\n",
    "\n",
    "preprocess = Preprocess()\n",
    "\n",
    "new_docs = [\n",
    "    \"This is a new document about space, including words like space, satellite, launch, orbit.\",\n",
    "    \"This is a new document about Microsoft Windows, including words like windows, files, dos.\"\n",
    "]\n",
    "\n",
    "# Parse with the training vocabulary so the bag-of-words columns align with the model.\n",
    "parsed_new_docs, new_bow = preprocess.parse(new_docs, vocab=dataset.vocab)\n",
    "# Bag-of-words must be a float tensor on the same device as the model;\n",
    "# trainer.test presumably returns doc-topic distributions (theta) — confirm in TopMost docs.\n",
    "new_theta = trainer.test(torch.as_tensor(new_bow, device=device).float())\n",
    "\n",
    "# Most probable topic per document; [43 16] matches Topic 43 (space) and\n",
    "# Topic 16 (windows) in the training cell's top-word listing.\n",
    "print(new_theta.argmax(1))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch1.13py3.8",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
